1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.lucene.classification;
18  
19  import org.apache.lucene.analysis.Analyzer;
20  import org.apache.lucene.analysis.TokenStream;
21  import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
22  import org.apache.lucene.document.Document;
23  import org.apache.lucene.index.LeafReader;
24  import org.apache.lucene.index.IndexableField;
25  import org.apache.lucene.index.MultiFields;
26  import org.apache.lucene.index.Term;
27  import org.apache.lucene.index.Terms;
28  import org.apache.lucene.index.TermsEnum;
29  import org.apache.lucene.search.BooleanClause;
30  import org.apache.lucene.search.BooleanQuery;
31  import org.apache.lucene.search.IndexSearcher;
32  import org.apache.lucene.search.Query;
33  import org.apache.lucene.search.ScoreDoc;
34  import org.apache.lucene.search.WildcardQuery;
35  import org.apache.lucene.util.BytesRef;
36  import org.apache.lucene.util.BytesRefBuilder;
37  import org.apache.lucene.util.IntsRefBuilder;
38  import org.apache.lucene.util.fst.Builder;
39  import org.apache.lucene.util.fst.FST;
40  import org.apache.lucene.util.fst.PositiveIntOutputs;
41  import org.apache.lucene.util.fst.Util;
42  
43  import java.io.IOException;
44  import java.util.List;
45  import java.util.Map;
46  import java.util.SortedMap;
47  import java.util.TreeMap;
48  
49  /**
50   * A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
51   * <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
52   * weights are calculated using
53   * {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
54   * and a per document basis and then a corresponding
55   * {@link org.apache.lucene.util.fst.FST} is used for class assignment.
56   * 
57   * @lucene.experimental
58   */
59  public class BooleanPerceptronClassifier implements Classifier<Boolean> {
60  
61    private Double threshold;
62    private final Integer batchSize;
63    private Terms textTerms;
64    private Analyzer analyzer;
65    private String textFieldName;
66    private FST<Long> fst;
67  
68    /**
69     * Create a {@link BooleanPerceptronClassifier}
70     * 
71     * @param threshold
72     *          the binary threshold for perceptron output evaluation
73     */
74    public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
75      this.threshold = threshold;
76      this.batchSize = batchSize;
77    }
78  
79    /**
80     * Default constructor, no batch updates of FST, perceptron threshold is
81     * calculated via underlying index metrics during
82     * {@link #train(org.apache.lucene.index.LeafReader, String, String, org.apache.lucene.analysis.Analyzer)
83     * training}
84     */
85    public BooleanPerceptronClassifier() {
86      batchSize = 1;
87    }
88  
89    /**
90     * {@inheritDoc}
91     */
92    @Override
93    public ClassificationResult<Boolean> assignClass(String text)
94        throws IOException {
95      if (textTerms == null) {
96        throw new IOException("You must first call Classifier#train");
97      }
98      Long output = 0l;
99      try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
100       CharTermAttribute charTermAttribute = tokenStream
101         .addAttribute(CharTermAttribute.class);
102       tokenStream.reset();
103       while (tokenStream.incrementToken()) {
104         String s = charTermAttribute.toString();
105         Long d = Util.get(fst, new BytesRef(s));
106         if (d != null) {
107           output += d;
108         }
109       }
110       tokenStream.end();
111     }
112 
113     return new ClassificationResult<>(output >= threshold, output.doubleValue());
114   }
115 
116   /**
117    * {@inheritDoc}
118    */
119   @Override
120   public void train(LeafReader leafReader, String textFieldName,
121                     String classFieldName, Analyzer analyzer) throws IOException {
122     train(leafReader, textFieldName, classFieldName, analyzer, null);
123   }
124 
125   /**
126    * {@inheritDoc}
127    */
128   @Override
129   public void train(LeafReader leafReader, String textFieldName,
130       String classFieldName, Analyzer analyzer, Query query) throws IOException {
131     this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
132 
133     if (textTerms == null) {
134       throw new IOException("term vectors need to be available for field " + textFieldName);
135     }
136 
137     this.analyzer = analyzer;
138     this.textFieldName = textFieldName;
139 
140     if (threshold == null || threshold == 0d) {
141       // automatic assign a threshold
142       long sumDocFreq = leafReader.getSumDocFreq(textFieldName);
143       if (sumDocFreq != -1) {
144         this.threshold = (double) sumDocFreq / 2d;
145       } else {
146         throw new IOException(
147             "threshold cannot be assigned since term vectors for field "
148                 + textFieldName + " do not exist");
149       }
150     }
151 
152     // TODO : remove this map as soon as we have a writable FST
153     SortedMap<String,Double> weights = new TreeMap<>();
154 
155     TermsEnum termsEnum = textTerms.iterator();
156     BytesRef textTerm;
157     while ((textTerm = termsEnum.next()) != null) {
158       weights.put(textTerm.utf8ToString(), (double) termsEnum.totalTermFreq());
159     }
160     updateFST(weights);
161 
162     IndexSearcher indexSearcher = new IndexSearcher(leafReader);
163 
164     int batchCount = 0;
165 
166     BooleanQuery.Builder q = new BooleanQuery.Builder();
167     q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST));
168     if (query != null) {
169       q.add(new BooleanClause(query, BooleanClause.Occur.MUST));
170     }
171     // run the search and use stored field values
172     for (ScoreDoc scoreDoc : indexSearcher.search(q.build(),
173         Integer.MAX_VALUE).scoreDocs) {
174       Document doc = indexSearcher.doc(scoreDoc.doc);
175 
176       IndexableField textField = doc.getField(textFieldName);
177       
178       // get the expected result
179       IndexableField classField = doc.getField(classFieldName);
180 
181       if (textField != null && classField != null) {
182         // assign class to the doc
183         ClassificationResult<Boolean> classificationResult = assignClass(textField.stringValue());
184         Boolean assignedClass = classificationResult.getAssignedClass();
185 
186         Boolean correctClass = Boolean.valueOf(classField.stringValue());
187         long modifier = correctClass.compareTo(assignedClass);
188         if (modifier != 0) {
189           updateWeights(leafReader, scoreDoc.doc, assignedClass,
190                 weights, modifier, batchCount % batchSize == 0);
191         }
192         batchCount++;
193       }
194     }
195     weights.clear(); // free memory while waiting for GC
196   }
197 
198   @Override
199   public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
200     throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
201   }
202 
203   private void updateWeights(LeafReader leafReader,
204                              int docId, Boolean assignedClass, SortedMap<String, Double> weights,
205                              double modifier, boolean updateFST) throws IOException {
206     TermsEnum cte = textTerms.iterator();
207 
208     // get the doc term vectors
209     Terms terms = leafReader.getTermVector(docId, textFieldName);
210 
211     if (terms == null) {
212       throw new IOException("term vectors must be stored for field "
213           + textFieldName);
214     }
215 
216     TermsEnum termsEnum = terms.iterator();
217 
218     BytesRef term;
219 
220     while ((term = termsEnum.next()) != null) {
221       cte.seekExact(term);
222       if (assignedClass != null) {
223         long termFreqLocal = termsEnum.totalTermFreq();
224         // update weights
225         Long previousValue = Util.get(fst, term);
226         String termString = term.utf8ToString();
227         weights.put(termString, previousValue + modifier * termFreqLocal);
228       }
229     }
230     if (updateFST) {
231       updateFST(weights);
232     }
233   }
234 
235   private void updateFST(SortedMap<String,Double> weights) throws IOException {
236     PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
237     Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
238     BytesRefBuilder scratchBytes = new BytesRefBuilder();
239     IntsRefBuilder scratchInts = new IntsRefBuilder();
240     for (Map.Entry<String,Double> entry : weights.entrySet()) {
241       scratchBytes.copyChars(entry.getKey());
242       fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
243           .getValue().longValue());
244     }
245     fst = fstBuilder.finish();
246   }
247 
248   /**
249    * {@inheritDoc}
250    */
251   @Override
252   public List<ClassificationResult<Boolean>> getClasses(String text)
253       throws IOException {
254     throw new RuntimeException("not implemented");
255   }
256 
257   /**
258    * {@inheritDoc}
259    */
260   @Override
261   public List<ClassificationResult<Boolean>> getClasses(String text, int max)
262       throws IOException {
263     throw new RuntimeException("not implemented");
264   }
265 
266 }